1   package org.apache.lucene.search.join;
2   
3   /*
4    * Licensed to the Apache Software Foundation (ASF) under one or more
5    * contributor license agreements.  See the NOTICE file distributed with
6    * this work for additional information regarding copyright ownership.
7    * The ASF licenses this file to You under the Apache License, Version 2.0
8    * (the "License"); you may not use this file except in compliance with
9    * the License.  You may obtain a copy of the License at
10   *
11   *     http://www.apache.org/licenses/LICENSE-2.0
12   *
13   * Unless required by applicable law or agreed to in writing, software
14   * distributed under the License is distributed on an "AS IS" BASIS,
15   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16   * See the License for the specific language governing permissions and
17   * limitations under the License.
18   */
19  
20  import java.io.IOException;
21  import java.util.Arrays;
22  import java.util.HashMap;
23  import java.util.LinkedList;
24  import java.util.Map;
25  import java.util.Queue;
26  
27  import org.apache.lucene.index.IndexWriter;
28  import org.apache.lucene.index.LeafReaderContext;
29  import org.apache.lucene.search.Collector;
30  import org.apache.lucene.search.FieldComparator;
31  import org.apache.lucene.search.FieldValueHitQueue;
32  import org.apache.lucene.search.LeafCollector;
33  import org.apache.lucene.search.LeafFieldComparator;
34  import org.apache.lucene.search.Query;
35  import org.apache.lucene.search.ScoreCachingWrappingScorer;
36  import org.apache.lucene.search.Scorer;
37  import org.apache.lucene.search.Scorer.ChildScorer;
38  import org.apache.lucene.search.Sort;
39  import org.apache.lucene.search.TopDocs;
40  import org.apache.lucene.search.TopDocsCollector;
41  import org.apache.lucene.search.TopFieldCollector;
42  import org.apache.lucene.search.TopScoreDocCollector;
43  import org.apache.lucene.search.grouping.GroupDocs;
44  import org.apache.lucene.search.grouping.TopGroups;
45  import org.apache.lucene.util.ArrayUtil;
46  
47  
/** Collects parent document hits for a Query containing one or more
49   *  BlockJoinQuery clauses, sorted by the
50   *  specified parent Sort.  Note that this cannot perform
51   *  arbitrary joins; rather, it requires that all joined
52   *  documents are indexed as a doc block (using {@link
53   *  IndexWriter#addDocuments} or {@link
54   *  IndexWriter#updateDocuments}).  Ie, the join is computed
55   *  at index time.
56   *
57   *  <p>This collector MUST be used with {@link ToParentBlockJoinIndexSearcher},
58   *  in order to work correctly.
59   *
60   *  <p>The parent Sort must only use
61   *  fields from the parent documents; sorting by field in
62   *  the child documents is not supported.</p>
63   *
64   *  <p>You should only use this
65   *  collector if one or more of the clauses in the query is
66   *  a {@link ToParentBlockJoinQuery}.  This collector will find those query
67   *  clauses and record the matching child documents for the
68   *  top scoring parent documents.</p>
69   *
70   *  <p>Multiple joins (star join) and nested joins and a mix
71   *  of the two are allowed, as long as in all cases the
72   *  documents corresponding to a single row of each joined
73   *  parent table were indexed as a doc block.</p>
74   *
75   *  <p>For the simple star join you can retrieve the
76   *  {@link TopGroups} instance containing each {@link ToParentBlockJoinQuery}'s
77   *  matching child documents for the top parent groups,
78   *  using {@link #getTopGroups}.  Ie,
79   *  a single query, which will contain two or more
80   *  {@link ToParentBlockJoinQuery}'s as clauses representing the star join,
81   *  can then retrieve two or more {@link TopGroups} instances.</p>
82   *
83   *  <p>For nested joins, the query will run correctly (ie,
84   *  match the right parent and child documents), however,
85   *  because TopGroups is currently unable to support nesting
86   *  (each group is not able to hold another TopGroups), you
87   *  are only able to retrieve the TopGroups of the first
88   *  join.  The TopGroups of the nested joins will not be
89   *  correct.
90   *
91   *  See {@link org.apache.lucene.search.join} for a code
92   *  sample.
93   *
94   * @lucene.experimental
95   */
96  public class ToParentBlockJoinCollector implements Collector {
97  
98    private final Sort sort;
99  
100   // Maps each BlockJoinQuery instance to its "slot" in
101   // joinScorers and in OneGroup's cached doc/scores/count:
102   private final Map<Query,Integer> joinQueryID = new HashMap<>();
103   private final int numParentHits;
104   private final FieldValueHitQueue<OneGroup> queue;
105   private final FieldComparator<?>[] comparators;
106   private final boolean trackMaxScore;
107   private final boolean trackScores;
108 
109   private ToParentBlockJoinQuery.BlockJoinScorer[] joinScorers = new ToParentBlockJoinQuery.BlockJoinScorer[0];
110   private boolean queueFull;
111 
112   private OneGroup bottom;
113   private int totalHitCount;
114   private float maxScore = Float.NaN;
115 
116   /**  Creates a ToParentBlockJoinCollector.  The provided sort must
117    *  not be null.  If you pass true trackScores, all
118    *  ToParentBlockQuery instances must not use
119    *  ScoreMode.None. */
120   public ToParentBlockJoinCollector(Sort sort, int numParentHits, boolean trackScores, boolean trackMaxScore) throws IOException {
121     // TODO: allow null sort to be specialized to relevance
122     // only collector
123     this.sort = sort;
124     this.trackMaxScore = trackMaxScore;
125     if (trackMaxScore) {
126       maxScore = Float.MIN_VALUE;
127     }
128     //System.out.println("numParentHits=" + numParentHits);
129     this.trackScores = trackScores;
130     this.numParentHits = numParentHits;
131     queue = FieldValueHitQueue.create(sort.getSort(), numParentHits);
132     comparators = queue.getComparators();
133   }
134   
135   private static final class OneGroup extends FieldValueHitQueue.Entry {
136     public OneGroup(int comparatorSlot, int parentDoc, float parentScore, int numJoins, boolean doScores) {
137       super(comparatorSlot, parentDoc, parentScore);
138       //System.out.println("make OneGroup parentDoc=" + parentDoc);
139       docs = new int[numJoins][];
140       for(int joinID=0;joinID<numJoins;joinID++) {
141         docs[joinID] = new int[5];
142       }
143       if (doScores) {
144         scores = new float[numJoins][];
145         for(int joinID=0;joinID<numJoins;joinID++) {
146           scores[joinID] = new float[5];
147         }
148       }
149       counts = new int[numJoins];
150     }
151     LeafReaderContext readerContext;
152     int[][] docs;
153     float[][] scores;
154     int[] counts;
155   }
156 
157   @Override
158   public LeafCollector getLeafCollector(final LeafReaderContext context)
159       throws IOException {
160     final LeafFieldComparator[] comparators = queue.getComparators(context);
161     final int[] reverseMul = queue.getReverseMul();
162     final int docBase = context.docBase;
163     return new LeafCollector() {
164 
165       private Scorer scorer;
166 
167       @Override
168       public void setScorer(Scorer scorer) throws IOException {
169         //System.out.println("C.setScorer scorer=" + scorer);
170         // Since we invoke .score(), and the comparators likely
171         // do as well, cache it so it's only "really" computed
172         // once:
173         if (scorer instanceof ScoreCachingWrappingScorer == false) {
174           scorer = new ScoreCachingWrappingScorer(scorer);
175         }
176         this.scorer = scorer;
177         for (LeafFieldComparator comparator : comparators) {
178           comparator.setScorer(scorer);
179         }
180         Arrays.fill(joinScorers, null);
181 
182         Queue<Scorer> queue = new LinkedList<>();
183         //System.out.println("\nqueue: add top scorer=" + scorer);
184         queue.add(scorer);
185         while ((scorer = queue.poll()) != null) {
186           //System.out.println("  poll: " + scorer + "; " + scorer.getWeight().getQuery());
187           if (scorer instanceof ToParentBlockJoinQuery.BlockJoinScorer) {
188             enroll((ToParentBlockJoinQuery) scorer.getWeight().getQuery(), (ToParentBlockJoinQuery.BlockJoinScorer) scorer);
189           }
190 
191           for (ChildScorer sub : scorer.getChildren()) {
192             //System.out.println("  add sub: " + sub.child + "; " + sub.child.getWeight().getQuery());
193             queue.add(sub.child);
194           }
195         }
196       }
197       
198       @Override
199       public void collect(int parentDoc) throws IOException {
200       //System.out.println("\nC parentDoc=" + parentDoc);
201         totalHitCount++;
202 
203         float score = Float.NaN;
204 
205         if (trackMaxScore) {
206           score = scorer.score();
207           maxScore = Math.max(maxScore, score);
208         }
209 
210         // TODO: we could sweep all joinScorers here and
211         // aggregate total child hit count, so we can fill this
212         // in getTopGroups (we wire it to 0 now)
213 
214         if (queueFull) {
215           //System.out.println("  queueFull");
216           // Fastmatch: return if this hit is not competitive
217           int c = 0;
218           for (int i = 0; i < comparators.length; ++i) {
219             c = reverseMul[i] * comparators[i].compareBottom(parentDoc);
220             if (c != 0) {
221               break;
222             }
223           }
224           if (c <= 0) { // in case of equality, this hit is not competitive as docs are visited in order
225             // Definitely not competitive.
226             //System.out.println("    skip");
227             return;
228           }
229 
230           //System.out.println("    competes!  doc=" + (docBase + parentDoc));
231 
232           // This hit is competitive - replace bottom element in queue & adjustTop
233           for (LeafFieldComparator comparator : comparators) {
234             comparator.copy(bottom.slot, parentDoc);
235           }
236           if (!trackMaxScore && trackScores) {
237             score = scorer.score();
238           }
239           bottom.doc = docBase + parentDoc;
240           bottom.readerContext = context;
241           bottom.score = score;
242           copyGroups(bottom);
243           bottom = queue.updateTop();
244 
245           for (LeafFieldComparator comparator : comparators) {
246             comparator.setBottom(bottom.slot);
247           }
248         } else {
249           // Startup transient: queue is not yet full:
250           final int comparatorSlot = totalHitCount - 1;
251 
252           // Copy hit into queue
253           for (LeafFieldComparator comparator : comparators) {
254             comparator.copy(comparatorSlot, parentDoc);
255           }
256           //System.out.println("  startup: new OG doc=" + (docBase+parentDoc));
257           if (!trackMaxScore && trackScores) {
258             score = scorer.score();
259           }
260           final OneGroup og = new OneGroup(comparatorSlot, docBase+parentDoc, score, joinScorers.length, trackScores);
261           og.readerContext = context;
262           copyGroups(og);
263           bottom = queue.add(og);
264           queueFull = totalHitCount == numParentHits;
265           if (queueFull) {
266             // End of startup transient: queue just filled up:
267             for (LeafFieldComparator comparator : comparators) {
268               comparator.setBottom(bottom.slot);
269             }
270           }
271         }
272       }
273       
274       // Pulls out child doc and scores for all join queries:
275       private void copyGroups(OneGroup og) {
276         // While rare, it's possible top arrays could be too
277         // short if join query had null scorer on first
278         // segment(s) but then became non-null on later segments
279         final int numSubScorers = joinScorers.length;
280         if (og.docs.length < numSubScorers) {
281           // While rare, this could happen if join query had
282           // null scorer on first segment(s) but then became
283           // non-null on later segments
284           og.docs = ArrayUtil.grow(og.docs);
285         }
286         if (og.counts.length < numSubScorers) {
287           og.counts = ArrayUtil.grow(og.counts);
288         }
289         if (trackScores && og.scores.length < numSubScorers) {
290           og.scores = ArrayUtil.grow(og.scores);
291         }
292 
293         //System.out.println("\ncopyGroups parentDoc=" + og.doc);
294         for(int scorerIDX = 0;scorerIDX < numSubScorers;scorerIDX++) {
295           final ToParentBlockJoinQuery.BlockJoinScorer joinScorer = joinScorers[scorerIDX];
296           //System.out.println("  scorer=" + joinScorer);
297           if (joinScorer != null && docBase + joinScorer.getParentDoc() == og.doc) {
298             og.counts[scorerIDX] = joinScorer.getChildCount();
299             //System.out.println("    count=" + og.counts[scorerIDX]);
300             og.docs[scorerIDX] = joinScorer.swapChildDocs(og.docs[scorerIDX]);
301             assert og.docs[scorerIDX].length >= og.counts[scorerIDX]: "length=" + og.docs[scorerIDX].length + " vs count=" + og.counts[scorerIDX];
302             //System.out.println("    len=" + og.docs[scorerIDX].length);
303             /*
304               for(int idx=0;idx<og.counts[scorerIDX];idx++) {
305               System.out.println("    docs[" + idx + "]=" + og.docs[scorerIDX][idx]);
306               }
307             */
308             if (trackScores) {
309               //System.out.println("    copy scores");
310               og.scores[scorerIDX] = joinScorer.swapChildScores(og.scores[scorerIDX]);
311               assert og.scores[scorerIDX].length >= og.counts[scorerIDX]: "length=" + og.scores[scorerIDX].length + " vs count=" + og.counts[scorerIDX];
312             }
313           } else {
314             og.counts[scorerIDX] = 0;
315           }
316         }
317       }
318     };
319   }
320 
321   private void enroll(ToParentBlockJoinQuery query, ToParentBlockJoinQuery.BlockJoinScorer scorer) {
322     scorer.trackPendingChildHits();
323     final Integer slot = joinQueryID.get(query);
324     if (slot == null) {
325       joinQueryID.put(query, joinScorers.length);
326       //System.out.println("found JQ: " + query + " slot=" + joinScorers.length);
327       final ToParentBlockJoinQuery.BlockJoinScorer[] newArray = new ToParentBlockJoinQuery.BlockJoinScorer[1+joinScorers.length];
328       System.arraycopy(joinScorers, 0, newArray, 0, joinScorers.length);
329       joinScorers = newArray;
330       joinScorers[joinScorers.length-1] = scorer;
331     } else {
332       joinScorers[slot] = scorer;
333     }
334   }
335 
336   private OneGroup[] sortedGroups;
337 
338   private void sortQueue() {
339     sortedGroups = new OneGroup[queue.size()];
340     for(int downTo=queue.size()-1;downTo>=0;downTo--) {
341       sortedGroups[downTo] = queue.pop();
342     }
343   }
344 
345   /** Returns the TopGroups for the specified
346    *  BlockJoinQuery. The groupValue of each GroupDocs will
347    *  be the parent docID for that group.
348    *  The number of documents within each group is calculated as minimum of <code>maxDocsPerGroup</code>
349    *  and number of matched child documents for that group.
350    *  Returns null if no groups matched.
351    *
352    * @param query Search query
353    * @param withinGroupSort Sort criteria within groups
354    * @param offset Parent docs offset
355    * @param maxDocsPerGroup Upper bound of documents per group number
356    * @param withinGroupOffset Offset within each group of child docs
357    * @param fillSortFields Specifies whether to add sort fields or not
358    * @return TopGroups for specified query
359    * @throws IOException if there is a low-level I/O error
360    */
361   public TopGroups<Integer> getTopGroups(ToParentBlockJoinQuery query, Sort withinGroupSort, int offset,
362                                          int maxDocsPerGroup, int withinGroupOffset, boolean fillSortFields)
363     throws IOException {
364 
365     final Integer _slot = joinQueryID.get(query);
366     if (_slot == null && totalHitCount == 0) {
367       return null;
368     }
369 
370     if (sortedGroups == null) {
371       if (offset >= queue.size()) {
372         return null;
373       }
374       sortQueue();
375     } else if (offset > sortedGroups.length) {
376       return null;
377     }
378 
379     return accumulateGroups(_slot == null ? -1 : _slot.intValue(), offset, maxDocsPerGroup, withinGroupOffset, withinGroupSort, fillSortFields);
380   }
381 
382   /**
383    *  Accumulates groups for the BlockJoinQuery specified by its slot.
384    *
385    * @param slot Search query's slot
386    * @param offset Parent docs offset
387    * @param maxDocsPerGroup Upper bound of documents per group number
388    * @param withinGroupOffset Offset within each group of child docs
389    * @param withinGroupSort Sort criteria within groups
390    * @param fillSortFields Specifies whether to add sort fields or not
391    * @return TopGroups for the query specified by slot
392    * @throws IOException if there is a low-level I/O error
393    */
394   @SuppressWarnings({"unchecked","rawtypes"})
395   private TopGroups<Integer> accumulateGroups(int slot, int offset, int maxDocsPerGroup,
396                                               int withinGroupOffset, Sort withinGroupSort, boolean fillSortFields) throws IOException {
397     final GroupDocs<Integer>[] groups = new GroupDocs[sortedGroups.length - offset];
398     final FakeScorer fakeScorer = new FakeScorer();
399 
400     int totalGroupedHitCount = 0;
401     //System.out.println("slot=" + slot);
402 
403     for(int groupIDX=offset;groupIDX<sortedGroups.length;groupIDX++) {
404       final OneGroup og = sortedGroups[groupIDX];
405       final int numChildDocs;
406       if (slot == -1 || slot >= og.counts.length) {
407         numChildDocs = 0;
408       } else {
409         numChildDocs = og.counts[slot];
410       }
411 
412       // Number of documents in group should be bounded to prevent redundant memory allocation
413       final int numDocsInGroup = Math.max(1, Math.min(numChildDocs, maxDocsPerGroup));
414       //System.out.println("parent doc=" + og.doc + " numChildDocs=" + numChildDocs + " maxDocsPG=" + maxDocsPerGroup);
415 
416       // At this point we hold all docs w/ in each group,
417       // unsorted; we now sort them:
418       final TopDocsCollector<?> collector;
419       if (withinGroupSort == null) {
420         //System.out.println("sort by score");
421         // Sort by score
422         if (!trackScores) {
423           throw new IllegalArgumentException("cannot sort by relevance within group: trackScores=false");
424         }
425         collector = TopScoreDocCollector.create(numDocsInGroup);
426       } else {
427         // Sort by fields
428         collector = TopFieldCollector.create(withinGroupSort, numDocsInGroup, fillSortFields, trackScores, trackMaxScore);
429       }
430 
431       LeafCollector leafCollector = collector.getLeafCollector(og.readerContext);
432       leafCollector.setScorer(fakeScorer);
433       for(int docIDX=0;docIDX<numChildDocs;docIDX++) {
434         //System.out.println("docIDX=" + docIDX + " vs " + og.docs[slot].length);
435         final int doc = og.docs[slot][docIDX];
436         fakeScorer.doc = doc;
437         if (trackScores) {
438           fakeScorer.score = og.scores[slot][docIDX];
439         }
440         leafCollector.collect(doc);
441       }
442       totalGroupedHitCount += numChildDocs;
443 
444       final Object[] groupSortValues;
445 
446       if (fillSortFields) {
447         groupSortValues = new Object[comparators.length];
448         for(int sortFieldIDX=0;sortFieldIDX<comparators.length;sortFieldIDX++) {
449           groupSortValues[sortFieldIDX] = comparators[sortFieldIDX].value(og.slot);
450         }
451       } else {
452         groupSortValues = null;
453       }
454 
455       final TopDocs topDocs = collector.topDocs(withinGroupOffset, numDocsInGroup);
456 
457       groups[groupIDX-offset] = new GroupDocs<>(og.score,
458                                                        topDocs.getMaxScore(),
459                                                        numChildDocs,
460                                                        topDocs.scoreDocs,
461                                                        og.doc,
462                                                        groupSortValues);
463     }
464 
465     return new TopGroups<>(new TopGroups<>(sort.getSort(),
466                                                        withinGroupSort == null ? null : withinGroupSort.getSort(),
467                                                        0, totalGroupedHitCount, groups, maxScore),
468                                   totalHitCount);
469   }
470 
471   /** Returns the TopGroups for the specified BlockJoinQuery.
472    *  The groupValue of each GroupDocs will be the parent docID for that group.
473    *  The number of documents within each group
474    *  equals to the total number of matched child documents for that group.
475    *  Returns null if no groups matched.
476    *
477    * @param query Search query
478    * @param withinGroupSort Sort criteria within groups
479    * @param offset Parent docs offset
480    * @param withinGroupOffset Offset within each group of child docs
481    * @param fillSortFields Specifies whether to add sort fields or not
482    * @return TopGroups for specified query
483    * @throws IOException if there is a low-level I/O error
484    */
485   public TopGroups<Integer> getTopGroupsWithAllChildDocs(ToParentBlockJoinQuery query, Sort withinGroupSort, int offset,
486                                                          int withinGroupOffset, boolean fillSortFields)
487     throws IOException {
488 
489     return getTopGroups(query, withinGroupSort, offset, Integer.MAX_VALUE, withinGroupOffset, fillSortFields);
490   }
491   
492   /**
493    * Returns the highest score across all collected parent hits, as long as
494    * <code>trackMaxScores=true</code> was passed
495    * {@link #ToParentBlockJoinCollector(Sort, int, boolean, boolean) on
496    * construction}. Else, this returns <code>Float.NaN</code>
497    */
498   public float getMaxScore() {
499     return maxScore;
500   }
501 
502   @Override
503   public boolean needsScores() {
504     // needed so that eg. BooleanQuery does not rewrite its MUST clauses to
505     // FILTER since the filter scorers are hidden in Scorer.getChildren().
506     return true;
507   }
508 }